{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WQpNapZNWuXP"
   },
   "source": [
    "\n",
    "**Best-of-n sampling as an alternative to RLHF**\n",
    "\n",
    "This notebook compares reward-model scores of prompt-based responses from:\n",
    "1. a base model (`gpt2-imdb`)\n",
    "2. an RLHF-tuned model (`gpt2-imdb-pos-v2`) trained from this base model\n",
    "3. the same base model, from which we sample n responses per prompt, score them, and keep the highest-scoring one (the `best-of-n sampled` model)\n",
    "\n",
    "Install and import dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vDA6qayz692w"
   },
   "outputs": [],
   "source": [
    "# `datasets` and `pandas` are imported directly in the next cell, so install them\n",
    "# explicitly rather than relying on them arriving as transitive dependencies of `trl`\n",
    "%pip install transformers trl datasets pandas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "M1s_iNm773hM"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import pandas as pd\n",
    "\n",
    "from transformers import pipeline, AutoTokenizer\n",
    "from datasets import load_dataset\n",
    "\n",
    "from trl import AutoModelForCausalLMWithValueHead\n",
    "from trl.core import LengthSampler\n",
    "\n",
    "\n",
    "def pick_device() -> str:\n",
    "    \"\"\"Return the device type to run the models on, falling back to CPU.\n",
    "\n",
    "    Uses the generic `torch.accelerator` API when this PyTorch has it, otherwise\n",
    "    checks CUDA and then MPS explicitly.\n",
    "    \"\"\"\n",
    "    if hasattr(torch, \"accelerator\"):\n",
    "        # `current_accelerator()` may return None (no accelerator) or report the\n",
    "        # accelerator PyTorch was built for; check runtime availability first so a\n",
    "        # CPU-only machine gets \"cpu\" instead of an AttributeError or unusable device\n",
    "        if torch.accelerator.is_available():\n",
    "            return torch.accelerator.current_accelerator().type\n",
    "        return \"cpu\"\n",
    "    if torch.cuda.is_available():\n",
    "        return \"cuda\"\n",
    "    if torch.backends.mps.is_available():\n",
    "        return \"mps\"\n",
    "    return \"cpu\"\n",
    "\n",
    "\n",
    "device = pick_device()\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Y7hyrIrO8tcY"
   },
   "source": [
    "Various constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "MqS3OM6Q8x6g"
   },
   "outputs": [],
   "source": [
    "from transformers import set_seed\n",
    "\n",
    "ref_model_name = \"lvwerra/gpt2-imdb\"  # base model\n",
    "model_name = \"lvwerra/gpt2-imdb-pos-v2\"  # RLHF-tuned on top of the base model\n",
    "reward_model = \"lvwerra/distilbert-imdb\"  # sentiment classifier used as the reward model\n",
    "\n",
    "N_BEST_OF = 4  # candidate responses sampled per prompt for best-of-n\n",
    "\n",
    "# Prompt-batch selection and generation are both random; seed the Python, NumPy and\n",
    "# PyTorch RNGs so a Restart & Run All reproduces the same comparison table\n",
    "SEED = 42\n",
    "set_seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "c1YcXeElg6or"
   },
   "source": [
    "Models and tokenizers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "b855NrL181Hh"
   },
   "outputs": [],
   "source": [
    "# RLHF-tuned policy and the base (reference) policy it was tuned from\n",
    "model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)\n",
    "ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)\n",
    "\n",
    "# reward model: a sentiment classifier whose scores rate each response\n",
    "reward_pipe = pipeline(\"sentiment-analysis\", model=reward_model, device=device)\n",
    "\n",
    "# a single tokenizer (the base model's) encodes prompts and decodes both models' outputs;\n",
    "# its pad token is set to EOS, the same id the generation kwargs use for padding\n",
    "tokenizer = AutoTokenizer.from_pretrained(ref_model_name)\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "# put models on the accelerator; the trailing `;` suppresses the very long module repr\n",
    "model.to(device)\n",
    "ref_model.to(device);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Z1Cz0gCFhZYJ"
   },
   "source": [
    "Dataset building"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "LqLVEp5p_8XM"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating train split: 100%|██████████| 25000/25000 [00:00<00:00, 113700.67 examples/s]\n",
      "Generating test split: 100%|██████████| 25000/25000 [00:00<00:00, 131049.39 examples/s]\n",
      "Generating unsupervised split: 100%|██████████| 50000/50000 [00:00<00:00, 126486.39 examples/s]\n",
      "Filter: 100%|██████████| 25000/25000 [00:00<00:00, 238843.61 examples/s]\n",
      "Map:   0%|          | 0/24895 [00:00<?, ? examples/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1168 > 1024). Running this sequence through the model will result in indexing errors\n",
      "Map: 100%|██████████| 24895/24895 [00:17<00:00, 1462.36 examples/s]\n"
     ]
    }
   ],
   "source": [
    "def build_dataset(\n",
    "    tokenizer,\n",
    "    dataset_name=\"stanfordnlp/imdb\",\n",
    "    input_min_text_length=2,\n",
    "    input_max_text_length=8,\n",
    "):\n",
    "    \"\"\"Load the IMDb train split and attach a short tokenized prompt to each review.\n",
    "\n",
    "    Reviews of 200 characters or fewer are dropped. Every remaining row gets\n",
    "    `input_ids` (the first k tokens of the review, with k drawn per row from\n",
    "    `LengthSampler(input_min_text_length, input_max_text_length)`) and `query`\n",
    "    (those tokens decoded back to text).\n",
    "\n",
    "    Args:\n",
    "        tokenizer: tokenizer used to encode the reviews and decode the prompts.\n",
    "        dataset_name: dataset id with a `text` column and a `train` split.\n",
    "        input_min_text_length: lower bound passed to `LengthSampler` (tokens).\n",
    "        input_max_text_length: upper bound passed to `LengthSampler` (tokens).\n",
    "\n",
    "    Returns:\n",
    "        The filtered, tokenized dataset with its format set to torch.\n",
    "    \"\"\"\n",
    "    # load imdb with datasets\n",
    "    ds = load_dataset(dataset_name, split=\"train\")\n",
    "    ds = ds.rename_columns({\"text\": \"review\"})\n",
    "    ds = ds.filter(lambda x: len(x[\"review\"]) > 200, batched=False)\n",
    "\n",
    "    input_size = LengthSampler(input_min_text_length, input_max_text_length)\n",
    "\n",
    "    def tokenize(sample):\n",
    "        # Only a short prefix is kept, so truncate while encoding: the leading tokens\n",
    "        # are identical, but long reviews no longer trigger the \"sequence length is\n",
    "        # longer than the specified maximum\" warning.\n",
    "        ids = tokenizer.encode(\n",
    "            sample[\"review\"], truncation=True, max_length=input_max_text_length\n",
    "        )\n",
    "        sample[\"input_ids\"] = ids[: input_size()]\n",
    "        sample[\"query\"] = tokenizer.decode(sample[\"input_ids\"])\n",
    "        return sample\n",
    "\n",
    "    ds = ds.map(tokenize, batched=False)\n",
    "    ds.set_format(type=\"torch\")\n",
    "    return ds\n",
    "\n",
    "\n",
    "dataset = build_dataset(tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "AqA2McjMAxNw"
   },
   "outputs": [],
   "source": [
    "# Plain sampling from the full distribution: top_k=0 turns top-k filtering off and\n",
    "# top_p=1.0 turns nucleus filtering off\n",
    "gen_kwargs = {\n",
    "    \"min_length\": -1,\n",
    "    \"top_k\": 0,  # an int, as `generate` expects; 0 disables top-k\n",
    "    \"top_p\": 1.0,\n",
    "    \"do_sample\": True,\n",
    "    \"pad_token_id\": tokenizer.eos_token_id,\n",
    "}\n",
    "# Reward-pipeline options: raw logits (function_to_apply=\"none\") for *every* label\n",
    "# (top_k=None). transformers returns those labels sorted by score, so callers must\n",
    "# select the label they want rather than taking the first entry.\n",
    "sent_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\", \"batch_size\": 16}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "L_q4qs35AxcR"
   },
   "outputs": [],
   "source": [
    "# each response gets a randomly drawn number of new tokens within these bounds\n",
    "output_min_length, output_max_length = 4, 16\n",
    "output_length_sampler = LengthSampler(output_min_length, output_max_length)\n",
    "\n",
    "#### get a batch of `bs` random prompts from the dataset\n",
    "bs = 16\n",
    "dataset.set_format(\"pandas\")  # so `dataset[:]` is a DataFrame we can `.sample`\n",
    "df_batch = dataset[:].sample(bs)\n",
    "output_data = {\"query\": df_batch[\"query\"].tolist()}\n",
    "query_tensors = df_batch[\"input_ids\"].tolist()\n",
    "\n",
    "# filled by the generation loop: one decoded response per prompt\n",
    "response_tensors_ref = []  # base model\n",
    "response_tensors = []  # RLHF model\n",
    "# filled by the generation loop: a list of candidate responses per prompt (best-of-n)\n",
    "response_tensors_best_of = []"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QVfpyHnZBLKY"
   },
   "source": [
    "\n",
    "Generation using various models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-imZ7uEFBNbw"
   },
   "outputs": [],
   "source": [
    "for i in range(bs):\n",
    "    gen_len = output_length_sampler()\n",
    "\n",
    "    # one unpadded prompt, shape (1, prompt_len): an all-ones attention mask is exact,\n",
    "    # and passing it avoids `generate` having to guess it when pad id == EOS id\n",
    "    query = torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device)\n",
    "    query_mask = torch.ones_like(query)\n",
    "\n",
    "    # skip_special_tokens drops EOS/padding (pad id == EOS id here), so the reward model\n",
    "    # never scores literal \"<|endoftext|>\" text\n",
    "    output = ref_model.generate(\n",
    "        query, attention_mask=query_mask, max_new_tokens=gen_len, **gen_kwargs\n",
    "    )\n",
    "    response_tensors_ref.append(tokenizer.decode(output[0], skip_special_tokens=True))\n",
    "\n",
    "    output = model.generate(\n",
    "        query, attention_mask=query_mask, max_new_tokens=gen_len, **gen_kwargs\n",
    "    )\n",
    "    response_tensors.append(tokenizer.decode(output[0], skip_special_tokens=True))\n",
    "\n",
    "    # N_BEST_OF copies of the same prompt for best-of-n sampling; candidates that stop\n",
    "    # early are filled with the pad id (= EOS), which decoding then skips\n",
    "    queries = query.repeat((N_BEST_OF, 1))\n",
    "    output = ref_model.generate(\n",
    "        queries,\n",
    "        attention_mask=torch.ones_like(queries),\n",
    "        max_new_tokens=gen_len,\n",
    "        **gen_kwargs,\n",
    "    )\n",
    "    # no .squeeze(): keep the (N_BEST_OF, seq_len) batch so batch_decode still yields\n",
    "    # one string per candidate when N_BEST_OF == 1\n",
    "    response_tensors_best_of.append(\n",
    "        tokenizer.batch_decode(output, skip_special_tokens=True)\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Jp5FC0Y5h_Sf"
   },
   "source": [
    "Scoring"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "PyDbbAQ0F_h7"
   },
   "outputs": [],
   "source": [
    "# label of the positive class in the reward model (lvwerra/distilbert-imdb);\n",
    "# adjust if a different reward model uses other label names\n",
    "POSITIVE_LABEL = \"POSITIVE\"\n",
    "\n",
    "\n",
    "def positive_scores(texts):\n",
    "    \"\"\"Score `texts` with `reward_pipe` and return the POSITIVE-label logit of each.\n",
    "\n",
    "    With `top_k=None` the pipeline returns every label per text ordered by score, so\n",
    "    the first entry is whichever label won (possibly NEGATIVE). Looking the score up\n",
    "    by label keeps the reward a consistent measure of positive sentiment.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if the pipeline output has no `POSITIVE_LABEL` entry.\n",
    "    \"\"\"\n",
    "    result = []\n",
    "    for output in reward_pipe(texts, **sent_kwargs):\n",
    "        score_by_label = {item[\"label\"]: item[\"score\"] for item in output}\n",
    "        result.append(score_by_label[POSITIVE_LABEL])\n",
    "    return result\n",
    "\n",
    "\n",
    "scores_ref = positive_scores(response_tensors_ref)\n",
    "scores = positive_scores(response_tensors)\n",
    "# one tensor of N_BEST_OF candidate scores per prompt\n",
    "scores_best_of = [\n",
    "    torch.tensor(positive_scores(candidates)) for candidates in response_tensors_best_of\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 682
    },
    "id": "nA1GDNJEiGm-",
    "outputId": "1389c686-0751-4304-dea2-b71fd68748e1"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>query</th>\n",
       "      <th>response (ref)</th>\n",
       "      <th>scores (ref)</th>\n",
       "      <th>response (RLHF)</th>\n",
       "      <th>scores (RLHF)</th>\n",
       "      <th>response (best_of)</th>\n",
       "      <th>scores (best_of)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>This movie is one of</td>\n",
       "      <td>This movie is one of the most twisted films I</td>\n",
       "      <td>2.094254</td>\n",
       "      <td>This movie is one of the finest directors of the</td>\n",
       "      <td>2.726879</td>\n",
       "      <td>This movie is one of the best looking movies I</td>\n",
       "      <td>2.705925</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>one may</td>\n",
       "      <td>one may feel we are seeing more</td>\n",
       "      <td>1.478813</td>\n",
       "      <td>one may not have great assets,</td>\n",
       "      <td>0.420451</td>\n",
       "      <td>one may not be supported, terrible</td>\n",
       "      <td>2.043730</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>This is an amazing film,</td>\n",
       "      <td>This is an amazing film, one of our favorite g...</td>\n",
       "      <td>2.871389</td>\n",
       "      <td>This is an amazing film, with all thelike wond...</td>\n",
       "      <td>2.918770</td>\n",
       "      <td>This is an amazing film, very moving and this ...</td>\n",
       "      <td>2.871694</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>just below</td>\n",
       "      <td>just below)and makes it seem as</td>\n",
       "      <td>0.861618</td>\n",
       "      <td>just below the world capital is a man</td>\n",
       "      <td>0.238322</td>\n",
       "      <td>just below) in this beautiful comedy.</td>\n",
       "      <td>2.760033</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Return To the</td>\n",
       "      <td>Return To the Museum. That film, called Bl</td>\n",
       "      <td>0.017376</td>\n",
       "      <td>Return To the East\" is a fascinating film,</td>\n",
       "      <td>2.648028</td>\n",
       "      <td>Return To the International: Miyazaki, by Ts</td>\n",
       "      <td>1.072344</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Brando plays the ace jet</td>\n",
       "      <td>Brando plays the ace jet fighter pilot, who stops</td>\n",
       "      <td>0.565335</td>\n",
       "      <td>Brando plays the ace jet pilot, who's a</td>\n",
       "      <td>0.668954</td>\n",
       "      <td>Brando plays the ace jet pilot Charlie; his fo...</td>\n",
       "      <td>0.679582</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>And a rather U</td>\n",
       "      <td>And a rather Utopian horror movie and with good</td>\n",
       "      <td>2.245751</td>\n",
       "      <td>And a rather Utop Congressional Movie, with a 45</td>\n",
       "      <td>0.307100</td>\n",
       "      <td>And a rather U of A complete combination of wh...</td>\n",
       "      <td>2.209265</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>The plot of this movie hangs</td>\n",
       "      <td>The plot of this movie hangs in the balance as...</td>\n",
       "      <td>1.122540</td>\n",
       "      <td>The plot of this movie hangs out well. The who...</td>\n",
       "      <td>2.195263</td>\n",
       "      <td>The plot of this movie hangs together within t...</td>\n",
       "      <td>1.310783</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>This isn't</td>\n",
       "      <td>This isn't all that bad; as for my</td>\n",
       "      <td>0.623968</td>\n",
       "      <td>This isn't a good film because I loved it</td>\n",
       "      <td>1.694601</td>\n",
       "      <td>This isn't bad writing, powerful actors and sp...</td>\n",
       "      <td>1.835901</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>This movie was for a</td>\n",
       "      <td>This movie was for a good reason!' Uh, OK</td>\n",
       "      <td>0.437566</td>\n",
       "      <td>This movie was for a fun, and grand Robinson</td>\n",
       "      <td>2.531890</td>\n",
       "      <td>This movie was for a bastard.&lt;br /&gt;&lt;br</td>\n",
       "      <td>2.311337</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>witty. funny.</td>\n",
       "      <td>witty. funny.&lt;|endoftext|&gt;</td>\n",
       "      <td>1.636344</td>\n",
       "      <td>witty. funny. funnier. more funny. funnier. fu...</td>\n",
       "      <td>2.132353</td>\n",
       "      <td>witty. funny. In the first scene the comical n...</td>\n",
       "      <td>2.164077</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>It's very hard</td>\n",
       "      <td>It's very hard to believe that anyone would en...</td>\n",
       "      <td>1.003727</td>\n",
       "      <td>It's very hard to wrap your mind around what h...</td>\n",
       "      <td>0.778888</td>\n",
       "      <td>It's very hard to wrap this up, due to lack of...</td>\n",
       "      <td>1.598843</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>Absolutely fantastic trash....this one</td>\n",
       "      <td>Absolutely fantastic trash....this one was hav...</td>\n",
       "      <td>1.350834</td>\n",
       "      <td>Absolutely fantastic trash....this one is a pe...</td>\n",
       "      <td>2.177587</td>\n",
       "      <td>Absolutely fantastic trash....this one ruins i...</td>\n",
       "      <td>2.221997</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>Prior to</td>\n",
       "      <td>Prior to this action film,</td>\n",
       "      <td>0.242474</td>\n",
       "      <td>Prior to Christian Kane's star</td>\n",
       "      <td>0.297408</td>\n",
       "      <td>Prior to his restoration, Passion</td>\n",
       "      <td>1.655534</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>i,</td>\n",
       "      <td>i, Marty Rathbun, Damon Wayans, Mark Watney and</td>\n",
       "      <td>0.105734</td>\n",
       "      <td>i, perhaps the great movie the director should...</td>\n",
       "      <td>1.336116</td>\n",
       "      <td>i, Martin was a thrill of 70s---wow!lee and Heath</td>\n",
       "      <td>2.277638</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>The film</td>\n",
       "      <td>The film takes a very grim craggy look</td>\n",
       "      <td>0.069017</td>\n",
       "      <td>The film is one of the best of that era</td>\n",
       "      <td>2.737825</td>\n",
       "      <td>The film's ambition was almost so great that its</td>\n",
       "      <td>2.357480</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                     query  \\\n",
       "0                     This movie is one of   \n",
       "1                                  one may   \n",
       "2                 This is an amazing film,   \n",
       "3                               just below   \n",
       "4                            Return To the   \n",
       "5                 Brando plays the ace jet   \n",
       "6                           And a rather U   \n",
       "7             The plot of this movie hangs   \n",
       "8                               This isn't   \n",
       "9                     This movie was for a   \n",
       "10                           witty. funny.   \n",
       "11                          It's very hard   \n",
       "12  Absolutely fantastic trash....this one   \n",
       "13                                Prior to   \n",
       "14                                      i,   \n",
       "15                                The film   \n",
       "\n",
       "                                       response (ref)  scores (ref)  \\\n",
       "0       This movie is one of the most twisted films I      2.094254   \n",
       "1                     one may feel we are seeing more      1.478813   \n",
       "2   This is an amazing film, one of our favorite g...      2.871389   \n",
       "3                     just below)and makes it seem as      0.861618   \n",
       "4          Return To the Museum. That film, called Bl      0.017376   \n",
       "5   Brando plays the ace jet fighter pilot, who stops      0.565335   \n",
       "6     And a rather Utopian horror movie and with good      2.245751   \n",
       "7   The plot of this movie hangs in the balance as...      1.122540   \n",
       "8                  This isn't all that bad; as for my      0.623968   \n",
       "9           This movie was for a good reason!' Uh, OK      0.437566   \n",
       "10                         witty. funny.<|endoftext|>      1.636344   \n",
       "11  It's very hard to believe that anyone would en...      1.003727   \n",
       "12  Absolutely fantastic trash....this one was hav...      1.350834   \n",
       "13                         Prior to this action film,      0.242474   \n",
       "14    i, Marty Rathbun, Damon Wayans, Mark Watney and      0.105734   \n",
       "15             The film takes a very grim craggy look      0.069017   \n",
       "\n",
       "                                      response (RLHF)  scores (RLHF)  \\\n",
       "0    This movie is one of the finest directors of the       2.726879   \n",
       "1                      one may not have great assets,       0.420451   \n",
       "2   This is an amazing film, with all thelike wond...       2.918770   \n",
       "3               just below the world capital is a man       0.238322   \n",
       "4          Return To the East\" is a fascinating film,       2.648028   \n",
       "5             Brando plays the ace jet pilot, who's a       0.668954   \n",
       "6    And a rather Utop Congressional Movie, with a 45       0.307100   \n",
       "7   The plot of this movie hangs out well. The who...       2.195263   \n",
       "8           This isn't a good film because I loved it       1.694601   \n",
       "9        This movie was for a fun, and grand Robinson       2.531890   \n",
       "10  witty. funny. funnier. more funny. funnier. fu...       2.132353   \n",
       "11  It's very hard to wrap your mind around what h...       0.778888   \n",
       "12  Absolutely fantastic trash....this one is a pe...       2.177587   \n",
       "13                     Prior to Christian Kane's star       0.297408   \n",
       "14  i, perhaps the great movie the director should...       1.336116   \n",
       "15            The film is one of the best of that era       2.737825   \n",
       "\n",
       "                                   response (best_of)  scores (best_of)  \n",
       "0      This movie is one of the best looking movies I          2.705925  \n",
       "1                  one may not be supported, terrible          2.043730  \n",
       "2   This is an amazing film, very moving and this ...          2.871694  \n",
       "3               just below) in this beautiful comedy.          2.760033  \n",
       "4        Return To the International: Miyazaki, by Ts          1.072344  \n",
       "5   Brando plays the ace jet pilot Charlie; his fo...          0.679582  \n",
       "6   And a rather U of A complete combination of wh...          2.209265  \n",
       "7   The plot of this movie hangs together within t...          1.310783  \n",
       "8   This isn't bad writing, powerful actors and sp...          1.835901  \n",
       "9              This movie was for a bastard.<br /><br          2.311337  \n",
       "10  witty. funny. In the first scene the comical n...          2.164077  \n",
       "11  It's very hard to wrap this up, due to lack of...          1.598843  \n",
       "12  Absolutely fantastic trash....this one ruins i...          2.221997  \n",
       "13                  Prior to his restoration, Passion          1.655534  \n",
       "14  i, Martin was a thrill of 70s---wow!lee and Heath          2.277638  \n",
       "15   The film's ambition was almost so great that its          2.357480  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# choose, per prompt, the best-of-n candidate with the highest reward score\n",
    "output_data.update(\n",
    "    {\n",
    "        \"response (ref)\": response_tensors_ref,\n",
    "        \"scores (ref)\": scores_ref,\n",
    "        \"response (RLHF)\": response_tensors,\n",
    "        \"scores (RLHF)\": scores,\n",
    "        \"response (best_of)\": [\n",
    "            candidates[candidate_scores.argmax().item()]\n",
    "            for candidates, candidate_scores in zip(\n",
    "                response_tensors_best_of, scores_best_of\n",
    "            )\n",
    "        ],\n",
    "        \"scores (best_of)\": [\n",
    "            candidate_scores.max().item() for candidate_scores in scores_best_of\n",
    "        ],\n",
    "    }\n",
    ")\n",
    "\n",
    "# store results in a dataframe (one row per prompt)\n",
    "df_results = pd.DataFrame(output_data)\n",
    "df_results"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
